/**
 * @File:   ray_voxel_intersection.cu
 * @Author: NVIDIA Corporation
 * @Date:   2021-10-13 00:00:00
 * @Last Modified by: Haozhe Xie
 * @Last Modified at: 2024-03-27 11:02:41
 * @Email:  root@haozhexie.com
 */

#include <torch/types.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>

#include <cuda.h>
#include <cuda_runtime.h>
#include <curand.h>
#include <curand_kernel.h>
#include <time.h>

//#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <vector>

#include "voxlib_common.h"

struct RVIP_Params {
  int voxel_dims[3];
  int voxel_strides[3];
  int max_samples;
  int img_dims[2];
  // Camera parameters
  float cam_ori[3];
  float cam_fwd[3];
  float cam_side[3];
  float cam_up[3];
  float cam_c[2];
  float cam_f;
  // unsigned long seed;
};

/*
  out_voxel_id: torch CUDA int32  [   img_dims[0], img_dims[1], max_samples, 1]
  out_depth:    torch CUDA float  [2, img_dims[0], img_dims[1], max_samples, 1]
  out_raydirs:  torch CUDA float  [   img_dims[0], img_dims[1],           1, 3]
  Image coordinates refer to the center of the pixel [0, 0, 0] at voxel
  coordinate is at the corner of the corner block (instead of at the center)
*/
template <int TILE_DIM>
static __global__ void ray_voxel_intersection_perspective_kernel(
    int32_t *__restrict__ out_voxel_id, float *__restrict__ out_depth,
    float *__restrict__ out_raydirs, const int32_t *__restrict__ in_voxel,
    const RVIP_Params p) {

  int img_coords[2];
  img_coords[1] = blockIdx.x * TILE_DIM + threadIdx.x;
  img_coords[0] = blockIdx.y * TILE_DIM + threadIdx.y;
  if (img_coords[0] >= p.img_dims[0] || img_coords[1] >= p.img_dims[1]) {
    return;
  }
  int pix_index = img_coords[0] * p.img_dims[1] + img_coords[1];

  // Calculate ray origin and direction
  float rayori[3], raydir[3];
  rayori[0] = p.cam_ori[0];
  rayori[1] = p.cam_ori[1];
  rayori[2] = p.cam_ori[2];

  // Camera intrinsics
  float ndc_imcoords[2];
  ndc_imcoords[0] = p.cam_c[0] - (float)img_coords[0]; // Flip height
  ndc_imcoords[1] = (float)img_coords[1] - p.cam_c[1];

  raydir[0] = p.cam_up[0] * ndc_imcoords[0] + p.cam_side[0] * ndc_imcoords[1] +
              p.cam_fwd[0] * p.cam_f;
  raydir[1] = p.cam_up[1] * ndc_imcoords[0] + p.cam_side[1] * ndc_imcoords[1] +
              p.cam_fwd[1] * p.cam_f;
  raydir[2] = p.cam_up[2] * ndc_imcoords[0] + p.cam_side[2] * ndc_imcoords[1] +
              p.cam_fwd[2] * p.cam_f;
  normalize<float, 3>(raydir);

  // Save out_raydirs
  out_raydirs[pix_index * 3] = raydir[0];
  out_raydirs[pix_index * 3 + 1] = raydir[1];
  out_raydirs[pix_index * 3 + 2] = raydir[2];

  float axis_t[3];
  int axis_int[3];
  // int axis_intbound[3];

  // Current voxel
  axis_int[0] = floorf(rayori[0]);
  axis_int[1] = floorf(rayori[1]);
  axis_int[2] = floorf(rayori[2]);

#pragma unroll
  for (int i = 0; i < 3; i++) {
    if (raydir[i] > 0) {
      // Initial t value
      // Handle boundary case where rayori[i] is a whole number. Always round Up
      // for the next block
      // axis_t[i] = (ceilf(nextafterf(rayori[i], HUGE_VALF)) - rayori[i]) /
      // raydir[i];
      axis_t[i] = ((float)(axis_int[i] + 1) - rayori[i]) / raydir[i];
    } else if (raydir[i] < 0) {
      axis_t[i] = ((float)axis_int[i] - rayori[i]) / raydir[i];
    } else {
      axis_t[i] = HUGE_VALF;
    }
  }

  // Fused raymarching and sampling
  bool quit = false;
  for (int cur_plane = 0; cur_plane < p.max_samples;
       cur_plane++) { // Last cycle is for calculating p2
    float t = nanf("0");
    float t2 = nanf("0");
    int32_t blk_id = 0;
    // Find the next intersection
    while (!quit) {
      // Find the next smallest t
      float tnow;
      // Hand unroll
      if (axis_t[0] <= axis_t[1] && axis_t[0] <= axis_t[2]) {
        // Update current t
        tnow = axis_t[0];
        // Update t candidates
        if (raydir[0] > 0) {
          axis_int[0] += 1;
          if (axis_int[0] >= p.voxel_dims[0]) {
            quit = true;
          }
          axis_t[0] = ((float)(axis_int[0] + 1) - rayori[0]) / raydir[0];
        } else {
          axis_int[0] -= 1;
          if (axis_int[0] < 0) {
            quit = true;
          }
          axis_t[0] = ((float)axis_int[0] - rayori[0]) / raydir[0];
        }
      } else if (axis_t[1] <= axis_t[2]) {
        tnow = axis_t[1];
        if (raydir[1] > 0) {
          axis_int[1] += 1;
          if (axis_int[1] >= p.voxel_dims[1]) {
            quit = true;
          }
          axis_t[1] = ((float)(axis_int[1] + 1) - rayori[1]) / raydir[1];
        } else {
          axis_int[1] -= 1;
          if (axis_int[1] < 0) {
            quit = true;
          }
          axis_t[1] = ((float)axis_int[1] - rayori[1]) / raydir[1];
        }
      } else {
        tnow = axis_t[2];
        if (raydir[2] > 0) {
          axis_int[2] += 1;
          if (axis_int[2] >= p.voxel_dims[2]) {
            quit = true;
          }
          axis_t[2] = ((float)(axis_int[2] + 1) - rayori[2]) / raydir[2];
        } else {
          axis_int[2] -= 1;
          if (axis_int[2] < 0) {
            quit = true;
          }
          axis_t[2] = ((float)axis_int[2] - rayori[2]) / raydir[2];
        }
      }

      if (quit) {
        break;
      }

      // Skip empty space
      // Could there be deadlock if the ray direction is away from the world?
      if (axis_int[0] < 0 || axis_int[0] >= p.voxel_dims[0] ||
          axis_int[1] < 0 || axis_int[1] >= p.voxel_dims[1] ||
          axis_int[2] < 0 || axis_int[2] >= p.voxel_dims[2]) {
        continue;
      }

      // Test intersection using voxel grid
      int64_t in_voxel_idx =
          static_cast<int64_t>(axis_int[0]) * p.voxel_strides[0] +
          static_cast<int64_t>(axis_int[1]) * p.voxel_strides[1] +
          static_cast<int64_t>(axis_int[2]) * p.voxel_strides[2];
      blk_id = in_voxel[in_voxel_idx];
      if (blk_id == 0) {
        continue;
      }

      // Now that there is an intersection
      t = tnow;
      // Calculate t2
      /*
      #pragma unroll
      for (int i=0; i<3; i++) {
          if (axis_t[i] <= axis_t[(i+1)%3] && axis_t[i] <= axis_t[(i+2)%3]) {
              t2 = axis_t[i];
              break;
          }
      }
      */
      // Hand unroll
      if (axis_t[0] <= axis_t[1] && axis_t[0] <= axis_t[2]) {
        t2 = axis_t[0];
      } else if (axis_t[1] <= axis_t[2]) {
        t2 = axis_t[1];
      } else {
        t2 = axis_t[2];
      }
      break;
    } // while !quit (ray marching loop)

    out_depth[pix_index * p.max_samples + cur_plane] = t;
    out_depth[p.img_dims[0] * p.img_dims[1] * p.max_samples +
              pix_index * p.max_samples + cur_plane] = t2;
    out_voxel_id[pix_index * p.max_samples + cur_plane] = blk_id;
  } // cur_plane
}

/*
  out:
   out_voxel_id: torch CUDA int32  [   img_dims[0], img_dims[1], max_samples, 1]
   out_depth:    torch CUDA float  [2, img_dims[0], img_dims[1], max_samples, 1]
   out_raydirs:  torch CUDA float  [   img_dims[0], img_dims[1], 1, 3]
  in:
   in_voxel:     torch CUDA int32  [X, Y, Z] [40, 512, 512]
   cam_ori:      torch float [3]
   cam_dir:      torch float [3]
   cam_up:       torch float [3]
   cam_f:        float
   cam_c:        int    [2]
   img_dims:     int    [2]
   max_samples:  int
*/
std::vector<torch::Tensor> ray_voxel_intersection_perspective_cuda(
    const torch::Tensor &in_voxel, const torch::Tensor &cam_ori,
    const torch::Tensor &cam_dir, const torch::Tensor &cam_up, float cam_f,
    const std::vector<float> &cam_c, const std::vector<int> &img_dims,
    int max_samples) {
  CHECK_CUDA(in_voxel);

  int curDevice = -1;
  cudaGetDevice(&curDevice);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
  torch::Device device = in_voxel.device();

  assert(in_voxel.dtype() == torch::kInt32); // Minecraft compatibility
  assert(in_voxel.dim() == 3);
  assert(cam_ori.dtype() == torch::kFloat32);
  assert(cam_ori.numel() == 3);
  assert(cam_dir.dtype() == torch::kFloat32);
  assert(cam_dir.numel() == 3);
  assert(cam_up.dtype() == torch::kFloat32);
  assert(cam_up.numel() == 3);
  assert(img_dims.size() == 2);

  RVIP_Params p;

  // Calculate camera rays
  const torch::Tensor cam_ori_c = cam_ori.cpu();
  const torch::Tensor cam_dir_c = cam_dir.cpu();
  const torch::Tensor cam_up_c = cam_up.cpu();

  // Get the coordinate frame of camera space in world space
  normalize<float, 3>(p.cam_fwd, cam_dir_c.data_ptr<float>());
  cross<float>(p.cam_side, p.cam_fwd, cam_up_c.data_ptr<float>());
  normalize<float, 3>(p.cam_side);
  cross<float>(p.cam_up, p.cam_side, p.cam_fwd);
  normalize<float, 3>(p.cam_up); // Not absolutely necessary as both vectors are
                                 // normalized. But just in case...

  copyarr<float, 3>(p.cam_ori, cam_ori_c.data_ptr<float>());

  p.cam_f = cam_f;
  p.cam_c[0] = cam_c[0];
  p.cam_c[1] = cam_c[1];
  p.max_samples = max_samples;

  p.voxel_dims[0] = in_voxel.size(0);
  p.voxel_dims[1] = in_voxel.size(1);
  p.voxel_dims[2] = in_voxel.size(2);
  p.voxel_strides[0] = in_voxel.stride(0);
  p.voxel_strides[1] = in_voxel.stride(1);
  p.voxel_strides[2] = in_voxel.stride(2);

  p.img_dims[0] = img_dims[0];
  p.img_dims[1] = img_dims[1];

  // Create output tensors
  // For Minecraft Seg Mask
  torch::Tensor out_voxel_id =
      torch::empty({p.img_dims[0], p.img_dims[1], p.max_samples, 1},
                   torch::TensorOptions().dtype(torch::kInt32).device(device));

  torch::Tensor out_depth;
  // Produce two sets of localcoords, one for entry point, the other one for
  // exit point. They share the same corner_ids.
  out_depth = torch::empty(
      {2, p.img_dims[0], p.img_dims[1], p.max_samples, 1},
      torch::TensorOptions().dtype(torch::kFloat32).device(device));

  torch::Tensor out_raydirs = torch::empty({p.img_dims[0], p.img_dims[1], 1, 3},
                                           torch::TensorOptions()
                                               .dtype(torch::kFloat32)
                                               .device(device)
                                               .requires_grad(false));

  const int TILE_DIM = 8;
  dim3 dimGrid((p.img_dims[1] + TILE_DIM - 1) / TILE_DIM,
               (p.img_dims[0] + TILE_DIM - 1) / TILE_DIM, 1);
  dim3 dimBlock(TILE_DIM, TILE_DIM, 1);

  ray_voxel_intersection_perspective_kernel<TILE_DIM>
      <<<dimGrid, dimBlock, 0, stream>>>(
          out_voxel_id.data_ptr<int32_t>(), out_depth.data_ptr<float>(),
          out_raydirs.data_ptr<float>(), in_voxel.data_ptr<int32_t>(), p);

  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("Error in ray_voxel_intersection_perspective_kernel: %s\n",
           cudaGetErrorString(err));
  }
  return {out_voxel_id, out_depth, out_raydirs};
}
